# Copyright 2020 The Magenta Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MusicVAE generation script."""

# TODO(adarob): Add support for models with conditioning.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
#exec(open("music_vae_generate.py").read())
import os
import random
import sys
import time

from magenta import music as mm
import midi_io
import configs
import numpy as np
import tensorflow.compat.v1 as tf
import pretty_midi
import pickle
from music21 import *
import shutil
from magenta.models.music_vae import TrainedModel
import torch
import operator

def renormalizeProbs(probdict):
    """Rescale the values of ``probdict`` in place so that they sum to 1.

    Args:
      probdict: dict mapping each outcome to a weight. It is modified in
        place and nothing is returned.
    """
    total = sum(probdict.values())
    # Adding 0.0 forces float division even when the weights are ints.
    for key, weight in list(probdict.items()):
        probdict[key] = (weight + 0.0) / total

def probDictToChoice(pdict):
    """Sample one key of ``pdict`` with probability proportional to its value.

    ``pdict`` is renormalized in place (via ``renormalizeProbs``) before
    sampling. Keys whose probability is not strictly positive are never
    returned.

    Args:
      pdict: dict mapping outcomes to non-negative weights.

    Returns:
      The sampled key.

    Raises:
      ValueError: if no key has a positive probability.
    """
    renormalizeProbs(pdict)
    # (key, cumulative probability) in dict order. Bounds only grow because
    # every included probability is positive, so no sorting is needed.
    cumulative = []
    running = 0
    for key, prob in pdict.items():
        if prob > 0.0:
            running += prob
            cumulative.append((key, running))
    if not cumulative:
        raise ValueError(
            "probDictToChoice needs at least one positive probability")
    r = random.uniform(0.0, 1.0)
    for key, bound in cumulative:
        if r <= bound:
            return key
    # Floating-point rounding can leave the final bound slightly below 1.0.
    # In that case r falls past it, and the last key is the intended choice.
    return cumulative[-1][0]

flags = tf.app.flags
logging = tf.logging
FLAGS = flags.FLAGS
# NOTE(review): the command-line flag definitions below are disabled, so FLAGS
# defines no script-specific flags. The script hard-codes its config name and
# checkpoint path further down in this file instead. The definitions are kept
# as comments so they can be re-enabled by uncommenting.
#
# flags.DEFINE_string(
#     'run_dir', "./",
#     'Path to the directory where the latest checkpoint will be loaded from.')
# flags.DEFINE_string(
#     'checkpoint_file', None,
#     'Path to the checkpoint file. run_dir will take priority over this flag.')
# flags.DEFINE_string(
#     'output_dir', 'music_vae/generated',
#     'The directory where MIDI files will be saved to.')
# flags.DEFINE_string(
#     'config', "cat-mel_2bar_big",
#     'The name of the config to use.')
# flags.DEFINE_string(
#     'mode', 'sample',
#     'Generate mode (either `sample` or `interpolate`).')
# flags.DEFINE_string(
#     'input_midi_1', "",
#     'Path of start MIDI file for interpolation.')
# flags.DEFINE_string(
#     'input_midi_2', "",
#     'Path of end MIDI file for interpolation.')
# flags.DEFINE_integer(
#     'num_outputs', 5,
#     'In `sample` mode, the number of samples to produce. In `interpolate` '
#     'mode, the number of steps (including the endpoints).')
# flags.DEFINE_integer(
#     'max_batch_size', 8,
#     'The maximum batch size to use. Decrease if you are seeing an OOM.')
# flags.DEFINE_float(
#     'temperature', 0.5,
#     'The randomness of the decoding process.')
# flags.DEFINE_string(
#     'log', 'INFO',
#     'The threshold for what messages will be logged: '
#     'DEBUG, INFO, WARN, ERROR, or FATAL.')
def _slerp(p0, p1, t):
  """Spherical linear interpolation."""
  omega = np.arccos(
      np.dot(np.squeeze(p0/np.linalg.norm(p0)),
             np.squeeze(p1/np.linalg.norm(p1))))
  so = np.sin(omega)
  return np.sin((1.0-t)*omega) / so * p0 + np.sin(t*omega)/so * p1




# Model config and checkpoint are hard-coded. batch_size=2 matches the
# two-row latent batch built for every decode call below.
config = "cat-mel_2bar_small"
model = TrainedModel(
    configs.CONFIG_MAP[config], batch_size=2,
    checkpoint_dir_or_path="cat-mel_2bar_small/model.ckpt")

# Each element of `reals` is a sequence whose items are individually passed
# to model.decode as latent vectors (assumed 1-D, shape (z_size,) -- TODO
# confirm). NOTE(review): pickle.load can execute arbitrary code; only load
# trusted files here.
with open("pickles/recons.pcl", "rb") as recons_file:
    reals = pickle.load(recons_file)

# Decode attempts allowed for a single measure before that measure is skipped.
MAX_DECODE_ATTEMPTS = 100


def _midi_to_events(midi_path):
    """Parse `midi_path` with music21 into a list of (midi_pitch, quarterLength).

    Rests are encoded with pitch 0. Exact type checks are used, so subclasses
    of Note/Rest are skipped.
    """
    events = []
    for part in converter.parse(midi_path):
        for element in list(part):
            if type(element) is note.Note:
                events.append((element.pitch.midi, element.quarterLength))
            elif type(element) is note.Rest:
                events.append((0, element.quarterLength))
    return events


def _decode_measure(z, midi_path):
    """Decode latent batch `z`, write it to `midi_path`, and return its events.

    Returns:
      The (pitch, quarterLength) events of the first decoded sequence. Returns
      None if music21 failed to parse all MAX_DECODE_ATTEMPTS decodes.
    """
    for _ in range(MAX_DECODE_ATTEMPTS):
        results = model.decode(length=16, z=z, temperature=1.0)
        mm.sequence_proto_to_midi_file(results[0], midi_path)
        try:
            # Events are collected into a fresh list, so a failed parse never
            # leaves partial notes behind.
            return _midi_to_events(midi_path)
        except Exception:  # pylint: disable=broad-except
            # Presumably decoding samples at temperature 1.0, so a new attempt
            # may produce a file music21 can parse -- TODO confirm.
            continue
    return None


def _events_to_score(events):
    """Lay out (pitch, quarterLength) events back-to-back in a music21 Score."""
    score = stream.Score()
    onset = 0.0
    for pitch, duration in events:
        if pitch > 0:
            element = note.Note(pitch, quarterLength=duration)
        else:
            element = note.Rest(quarterLength=duration)
        score.insert(onset, element)
        onset += duration
    return score


os.makedirs("tmpmids", exist_ok=True)
os.makedirs("generatedMids", exist_ok=True)

all_meas = []      # per-measure event lists, parallel to all_magents
all_magents = []   # latent vector used for each entry of all_meas
for (q_, i) in enumerate(reals):
    meas = []  # events of the whole sequence q_
    for (k, val) in enumerate(i):
        # The model was built with batch_size=2, so the latent is duplicated.
        z = np.array([val, val])
        measure = _decode_measure(
            z, "tmpmids/" + str(q_) + "-" + str(k) + ".mid")
        if measure is None:
            continue  # every attempt failed to parse; skip this measure
        meas.extend(measure)
        all_meas.append(measure)
        all_magents.append(z[0, :])

    a = _events_to_score(meas)
    a.write(fmt="mid", fp="generatedMids/" + str(q_) + "-vae.mid")
